{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# launch XVFB if you run on a server\n",
    "import os\n",
    "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n",
    "    !bash ../xvfb start\n",
    "    os.environ['DISPLAY'] = ':1'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's make a TRPO!\n",
    "\n",
    "In this notebook we will write the code of the one Trust Region Policy Optimization.\n",
    "As usually, it contains a few different parts which we are going to reproduce.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import theano\n",
    "import theano.tensor as T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Observation Space Box(6,)\n",
      "Action Space Discrete(3)\n"
     ]
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "env = gym.make(\"Acrobot-v1\")\n",
    "env.reset()\n",
    "\n",
    "observation_shape = env.observation_space.shape\n",
    "n_actions = env.action_space.n\n",
    "print(\"Observation Space\", env.observation_space)\n",
    "print(\"Action Space\", env.action_space)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f83846e0240>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQsAAAD8CAYAAABgtYFHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADktJREFUeJzt3W2opGd9x/Hvr5sHbRXXJMew7K5sxKXoizZmD3ElpdhE\nS0zFzYsEIlKXsLDQWlAs2E0LLUJfaF+YIBR1aaRrUZPUB7KEtDZsEkpfGHPWPLuNOUpqDhvclTzY\nItpG/30x19Fx9yTn2j0zZ2Z2vx8Y5rqv+7pn/hNmf7nue64zk6pCklbzG5MuQNJsMCwkdTEsJHUx\nLCR1MSwkdTEsJHUZS1gkuTrJk0kWk+wbx3NIWl8Z9TqLJBuA7wLvBpaAB4H3V9V3RvpEktbVOGYW\nlwOLVfX9qvpf4DZg1xieR9I6OmcMj7kZeGZoewl4+ysdcNFFF9W2bdvGUIqkZYcPH/5RVc2d7vHj\nCIus0HfSuU6SvcBegDe+8Y0sLCyMoRRJy5L811qOH8dpyBKwdWh7C3D0xEFVtb+q5qtqfm7utMNO\n0joZR1g8CGxPckmS84AbgINjeB5J62jkpyFV9VKSPwO+AWwAPl9VT4z6eSStr3Fcs6Cq7gbuHsdj\nS5oMV3BK6mJYSOpiWEjqYlhI6mJYSOpiWEjqYlhI6mJYSOpiWEjqYlhI6mJYSOpiWEjqYlhI6mJY\nSOpiWEjqYlhI6mJYSOpiWEjqYlhI6mJYSOpiWEjqYlhI6mJYSOpiWEjqYlhI6mJYSOpiWEjqYlhI\n6mJYSOpiWEjqYlhI6mJYSOpiWEjqYlhI6mJYSOqyalgk+XySY0keH+q7IMk9SZ5q969v/Uny6SSL\nSR5Nctk4i5e0fnpmFv8IXH1C3z7gUFVtBw61bYD3ANvbbS/wmdGUKWnSVg2Lqvp34LkTuncBB1r7\nAHDtUP8XauCbwMYkm0ZVrKTJOd1rFhdX1bMA7f4NrX8z8MzQuKXWd5Ike5MsJFk4fvz4aZYhab2M\n+gJnVuirlQZW1f6qmq+q+bm5uRGXIWnUTjcsfrh8etHuj7X+JWDr0LgtwNHTL0/StDjdsDgI7G7t\n3cCdQ/0fbJ+K7AReXD5dkTTbzlltQJIvA+8ELkqyBPwN8AngjiR7gB8A17fhdwPXAIvAT4Abx1Cz\npAlYNSyq6v0vs+uqFcYW8KG1FiVp+riCU1IXw0JSF8NCUhfDQlIXw0JSF8NCUhfDQlIXw0JSF8NC\nUpcMFl1OuIhk8kVIZ77DVTV/ugevutx7PezYsYOFhYVJlyGd0ZKVvkGin6chkroYFpK6GBaSuhgW\nkroYFpK6GBaSuhgWkroYFpK6GBaSuhgWkroYFpK6GBaSuhgWkroYFpK6GBaSuhgWkroYFpK6GBaS\nuhgWkroYFpK6GBaSuhgWkroYFpK6rBoWSbYmuS/JkSRPJPlw678gyT1Jnmr3r2/9SfLpJItJHk1y\n2bhfhKTx65lZvAT8eVW9BdgJfCjJW4F9wKGq2g4catsA7wG2t9te4DMjr1rSuls1LKrq2ar6dmv/\nN3AE2AzsAg60YQeAa1t7F/CFGvgmsDHJppFXLmldndI1iyTbgLcBDwAXV9WzMAgU4A1t2GbgmaHD\nllqfpBnWHRZJXgN8FfhIVf34lYau0HfSDx8n2ZtkIcnC8ePHe8uQNCFdYZHkXAZB8cWq+lrr/uHy\n6UW7P9b6l4CtQ4dvAY6e+JhVtb+q5qtqfm5u7nTrl7ROej4NCXArcKSqPjW06yCwu7V3A3cO9X+w\nfSqyE3hx+XRF0uw6p2PMFcAfA48lebj1/SXwCeCOJHuAHwDXt313A9cAi8BPgBtHWrGkiVg1LKrq\nP1j5OgTAVSuML+BDa6xL0pRxBaekLoaFpC6GhaQuhoWkLoaFpC6GhaQuhoWkLoaFpC6GhaQuhoWk\nLoaFpC49f0gm/dLhw7/+Z0I7dpz0VS
U6QzmzULcTg+Ll+nRmMizU5ZVCwcA4OxgWWlVPGBgYZz7D\nQlIXw0JSF8NCUhfDQquaZ2EkYzTbDAt1eaUwMCjODoaFuq0UCgbF2cMVnDolhsPZy5mFpC6GhaQu\nhoWkLoaFpC6GhaQuhoWkLoaFpC6GhaQuhoWkLoaFpC6GhaQuhoWkLoaFpC6rhkWSVyX5VpJHkjyR\n5OOt/5IkDyR5KsntSc5r/ee37cW2f9t4X4Kk9dAzs/gZcGVV/S5wKXB1kp3AJ4Gbq2o78Dywp43f\nAzxfVW8Gbm7jJM24VcOiBv6nbZ7bbgVcCXyl9R8Arm3tXW2btv+qJH5PvDTjuq5ZJNmQ5GHgGHAP\n8D3ghap6qQ1ZAja39mbgGYC2/0XgwhUec2+ShSQLx48fX9urkDR2XWFRVT+vqkuBLcDlwFtWGtbu\nV5pFnPSDmFW1v6rmq2p+bm6ut15JE3JKn4ZU1QvA/cBOYGOS5a/l2wIcbe0lYCtA2/864LlRFCtp\ncno+DZlLsrG1Xw28CzgC3Adc14btBu5s7YNtm7b/3qryp7alGdfzhb2bgANJNjAIlzuq6q4k3wFu\nS/K3wEPArW38rcA/JVlkMKO4YQx1S1pnq4ZFVT0KvG2F/u8zuH5xYv9PgetHUp2kqeEKTkldDAtJ\nXQwLrap27Jh0CZoChoWkLoaFpC6GhaQuhoWkLoaFpC6GhaQuhoWkLoaFpC6GhUYmhw9PugSNkWEh\nqYthIamLYSGpi2EhqYthodOywDwLzE+6DK0jw0Kn5MSQMDDOHoaFur1cMBgYZwfDQl1WCwQD48xn\nWGhkDIwzm2EhqYthIamLYaGRmWdh0iVojAwLdVktCAyKM59hoTUzKM4OPb91KgGGwtnOmYWkLoaF\npC6GhaQuhoWkLoaFpC6GhaQuhoWkLt1hkWRDkoeS3NW2L0nyQJKnktye5LzWf37bXmz7t42ndEnr\n6VRmFh8GjgxtfxK4uaq2A88De1r/HuD5qnozcHMbJ2nGdYVFki3AHwH/0LYDXAl8pQ05AFzb2rva\nNm3/VW28pBnWO7O4BfgY8Iu2fSHwQlW91LaXgM2tvRl4BqDtf7GNlzTDVg2LJO8FjlXV8G/TrTRT\nqI59w4+7N8lCkoXjx493FStpcnpmFlcA70vyNHAbg9OPW4CNSZb/EG0LcLS1l4CtAG3/64DnTnzQ\nqtpfVfNVNT83N7emFyFp/FYNi6q6qaq2VNU24Abg3qr6AHAfcF0bthu4s7UPtm3a/nur6qSZhaTZ\nspZ1Fn8BfDTJIoNrEre2/luBC1v/R4F9aytR0jQ4pe+zqKr7gftb+/vA5SuM+Slw/QhqkzRFXMEp\nqYthIamLYSGpi2EhqYthIamLYSGpi2EhqYthIamLYSGpi2EhqYthIamLYSGpi2EhqYthIamLYSGp\ni2EhqYthIamLYSGpi2GhLrVjx6RL0IQZFpK6GBaSuhgWkroYFpK6GBaSuhgWkroYFpK6GBYaqRw+\nPOkSNCaGhaQup/TDyDq7uYrz7ObMQlIXw0JSF8NCUhfDQlIXw0JSF8NCUpeusEjydJLHkjycZKH1\nXZDkniRPtfvXt/4k+XSSxSSPJrlsnC9A0vo4lZnFH1TVpVU137b3AYeqajtwqG0DvAfY3m57gc+M\nqlhJk7OW05BdwIHWPgBcO9T/hRr4JrAxyaY1PI+kKdC7grOAf0tSwOeqaj9wcVU9C1BVzyZ5Qxu7\nGXhm6Nil1vfs8AMm2ctg5gHwsySPn+ZrmISLgB9NuohOs1QrzFa9s1QrwG+v5eDesLiiqo62QLgn\nyX++wtis0FcndQwCZz9AkoWh05upN0v1zlKtMFv1zlKtMKh3Lcd3nYZU1dF2fwz4OnA58MPl04t2\nf6wNXwK2Dh2+BTi6liIlTd6qYZHkt5K8drkN/CHwOHAQ2N2G7QbubO2DwAfbpyI7gReXT1ckza6e\n05
CLga8nWR7/par61yQPAnck2QP8ALi+jb8buAZYBH4C3NjxHPtPtfAJm6V6Z6lWmK16Z6lWWGO9\nqTrpcoIkncQVnJK6TDwsklyd5Mm24nPf6keMvZ7PJzk2/FHuNK9WTbI1yX1JjiR5IsmHp7XmJK9K\n8q0kj7RaP976L0nyQKv19iTntf7z2/Zi279tvWodqnlDkoeS3DUDtY53pXVVTewGbAC+B7wJOA94\nBHjrhGv6feAy4PGhvr8D9rX2PuCTrX0N8C8MPi7eCTwwgXo3AZe19muB7wJvncaa23O+prXPBR5o\nNdwB3ND6Pwv8SWv/KfDZ1r4BuH0C/30/CnwJuKttT3OtTwMXndA3svfBur6YFV7cO4BvDG3fBNw0\nyZpaHdtOCIsngU2tvQl4srU/B7x/pXETrP1O4N3TXjPwm8C3gbczWNh0zonvCeAbwDta+5w2LutY\n4xYGf8pwJXBX+4c1lbW2510pLEb2Ppj0acjLrfacNr+2WhVYbbXqRLSp79sY/B97Kmtu0/qHGazL\nuYfBzPKFqnpphXp+WWvb/yJw4XrVCtwCfAz4Rdu+kOmtFX610vpwWyENI3wfTPoLe7tWe06xqak/\nyWuArwIfqaoft4+6Vxy6Qt+61VxVPwcuTbKRwQK/t7xCPROrNcl7gWNVdTjJOzvqmYb3wshXWg+b\n9MxiVlZ7TvVq1STnMgiKL1bV11r3VNdcVS8A9zM4X96YZPl/XMP1/LLWtv91wHPrVOIVwPuSPA3c\nxuBU5JYprRUY/0rrSYfFg8D2doX5PAYXhg5OuKaVTO1q1QymELcCR6rqU0O7pq7mJHNtRkGSVwPv\nAo4A9wHXvUyty6/hOuDeaifY41ZVN1XVlqraxuB9eW9VfWAaa4V1Wmm9nhdgXuaizDUMruB/D/ir\nKajnywz+Qvb/GKTvHgbnnoeAp9r9BW1sgL9vtT8GzE+g3t9jMH18FHi43a6ZxpqB3wEearU+Dvx1\n638T8C0Gq37/GTi/9b+qbS+2/W+a0Hvinfzq05CprLXV9Ui7PbH8b2mU7wNXcErqMunTEEkzwrCQ\n1MWwkNTFsJDUxbCQ1MWwkNTFsJDUxbCQ1OX/ASKF9tx4Ki+cAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f838ddca8d0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "plt.imshow(env.render('rgb_array'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 1: Defining a network\n",
    "\n",
    "With all it's complexity, at it's core TRPO is yet another policy gradient method. \n",
    "\n",
    "This essentially means we're actually training a stochastic policy $ \\pi_\\theta(a|s) $. \n",
    "\n",
    "And yes, it's gonna be a neural network. So let's start by defining one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# input tensors\n",
    "observations = T.matrix(name=\"obs\")\n",
    "actions = T.ivector(name=\"action\")\n",
    "cummulative_returns = T.vector(name=\"G = r + gamma*r' + gamma^2*r'' + ...\")\n",
    "old_probs = T.matrix(name=\"action probabilities from previous iteration\")\n",
    "\n",
    "all_inputs = [observations, actions, cummulative_returns, old_probs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create neural network.\n",
    "from lasagne.layers import *\n",
    "\n",
    "nn = InputLayer((None,)+observation_shape, input_var=observations)\n",
    "\n",
    "<your network here >\n",
    "\n",
    "policy = <layer that predicts action probabilities >\n",
    "\n",
    "probs = get_output(policy)\n",
    "\n",
    "weights = get_all_params(policy, trainable=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 2: Actions and rollouts\n",
    "\n",
    "In this section, we'll define functions that take actions $ a \\sim \\pi_\\theta(a|s) $ and rollouts $ <s_0,a_0,s_1,a_1,s_2,a_2,...s_n,a_n> $."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compile function\n",
    "get_policy = theano.function([observations], probs, allow_input_downcast=True)\n",
    "\n",
    "\n",
    "def act(obs, sample=True):\n",
    "    \"\"\"\n",
    "    Samples action from policy distribution (sample = True) or takes most likely action (sample = False)\n",
    "    :param: obs - single observation vector\n",
    "    :param sample: if True, samples from \\pi, otherwise takes most likely action\n",
    "    :returns: action (single integer) and probabilities for all actions\n",
    "    \"\"\"\n",
    "\n",
    "    policy = get_policy([obs])[0]\n",
    "\n",
    "    if sample:\n",
    "        action = int(np.random.choice(n_actions, p=policy))\n",
    "    else:\n",
    "        action = int(np.argmax(policy))\n",
    "\n",
    "    return action, policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# demo\n",
    "print(\"sampled:\", [act(env.reset()) for _ in range(100)])\n",
    "print(\"greedy:\", [act(env.reset(), sample=False) for _ in range(100)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute cummulative reward just like you did in vanilla REINFORCE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.signal\n",
    "\n",
    "\n",
    "def get_cummulative_returns(r, gamma=1):\n",
    "    \"\"\"\n",
    "    Computes cummulative discounted rewards given immediate rewards\n",
    "    G_i = r_i + gamma*r_{i+1} + gamma^2*r_{i+2} + ...\n",
    "    Also known as R(s,a).\n",
    "    \"\"\"\n",
    "    r = np.array(r)\n",
    "    assert r.ndim >= 1\n",
    "    return scipy.signal.lfilter([1], [1, -gamma], r[::-1], axis=0)[::-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# simple demo on rewards [0,0,1,0,0,1]\n",
    "get_cummulative_returns([0, 0, 1, 0, 0, 1], gamma=0.9)"
   ]
  },
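  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check for the demo above, the same returns can be worked out by hand with the backward recursion $G_t = r_t + \\gamma G_{t+1}$, taking $G_6 = 0$ past the end of the episode. This is only an illustrative calculation, assuming zero-based time indices $t = 0, \\dots, 5$ and $\\gamma = 0.9$:\n",
    "\n",
    "$$\\begin{aligned}\n",
    "G_5 &= 1 \\\\\\\\\n",
    "G_4 &= 0 + 0.9 \\cdot 1 = 0.9 \\\\\\\\\n",
    "G_3 &= 0 + 0.9 \\cdot 0.9 = 0.81 \\\\\\\\\n",
    "G_2 &= 1 + 0.9 \\cdot 0.81 = 1.729 \\\\\\\\\n",
    "G_1 &= 0 + 0.9 \\cdot 1.729 = 1.5561 \\\\\\\\\n",
    "G_0 &= 0 + 0.9 \\cdot 1.5561 = 1.40049\n",
    "\\end{aligned}$$\n",
    "\n",
    "so the call should return approximately `[1.40049, 1.5561, 1.729, 0.81, 0.9, 1.0]`."
   ]
  },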
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Rollout**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rollout(env, act, max_pathlength=2500, n_timesteps=50000):\n",
    "    \"\"\"\n",
    "    Generate rollouts for training.\n",
    "    :param: env - environment in which we will make actions to generate rollouts.\n",
    "    :param: act - the function that can return policy and action given observation.\n",
    "    :param: max_pathlength - maximum size of one path that we generate.\n",
    "    :param: n_timesteps - total sum of sizes of all pathes we generate.\n",
    "    \"\"\"\n",
    "    paths = []\n",
    "\n",
    "    total_timesteps = 0\n",
    "    while total_timesteps < n_timesteps:\n",
    "        obervations, actions, rewards, action_probs = [], [], [], []\n",
    "        obervation = env.reset()\n",
    "        for _ in range(max_pathlength):\n",
    "            action, policy = act(obervation)\n",
    "            obervations.append(obervation)\n",
    "            actions.append(action)\n",
    "            action_probs.append(policy)\n",
    "            obervation, reward, done, _ = env.step(action)\n",
    "            rewards.append(reward)\n",
    "            total_timesteps += 1\n",
    "            if done or total_timesteps == n_timesteps:\n",
    "                path = {\"observations\": np.array(obervations),\n",
    "                        \"policy\": np.array(action_probs),\n",
    "                        \"actions\": np.array(actions),\n",
    "                        \"rewards\": np.array(rewards),\n",
    "                        \"cumulative_returns\": get_cummulative_returns(rewards),\n",
    "                        }\n",
    "                paths.append(path)\n",
    "                break\n",
    "    return paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "paths = rollout(env, act, max_pathlength=5, n_timesteps=100)\n",
    "print(paths[-1])\n",
    "assert (paths[0]['policy'].shape == (5, n_actions))\n",
    "assert (paths[0]['cumulative_returns'].shape == (5,))\n",
    "assert (paths[0]['rewards'].shape == (5,))\n",
    "assert (paths[0]['observations'].shape == (5,)+observation_shape)\n",
    "assert (paths[0]['actions'].shape == (5,))\n",
    "print('It\\'s ok')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 3: loss functions\n",
    "\n",
    "Now let's define the loss functions and constraints for actual TRPO training."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The surrogate reward should be\n",
    "$$J_{surr}= {1 \\over N} \\sum\\limits_{i=0}^N \\frac{\\pi_{\\theta}(s_i, a_i)}{\\pi_{\\theta_{old}}(s_i, a_i)}A_{\\theta_{old}(s_i, a_i)}$$\n",
    "\n",
    "For simplicity, let's use cummulative returns instead of advantage for now:\n",
    "$$J'_{surr}= {1 \\over N} \\sum\\limits_{i=0}^N \\frac{\\pi_{\\theta}(s_i, a_i)}{\\pi_{\\theta_{old}}(s_i, a_i)}G_{\\theta_{old}(s_i, a_i)}$$\n",
    "\n",
    "Or alternatively, minimize the surrogate loss:\n",
    "$$ L_{surr} = - J'_{surr} $$"
   ]
  },
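  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to see why this surrogate is reasonable (a short derivation added for illustration, using the identity $\\nabla_\\theta \\pi_\\theta = \\pi_\\theta \\nabla_\\theta \\log \\pi_\\theta$): at $\\theta = \\theta_{old}$ its gradient coincides with the usual REINFORCE policy gradient,\n",
    "\n",
    "$$\\nabla_\\theta J'_{surr}\\Big|_{\\theta=\\theta_{old}} = {1 \\over N} \\sum\\limits_{i=1}^N \\frac{\\nabla_\\theta \\pi_{\\theta}(a_i|s_i)}{\\pi_{\\theta_{old}}(a_i|s_i)}\\Big|_{\\theta=\\theta_{old}} G_{\\theta_{old}}(s_i, a_i) = {1 \\over N} \\sum\\limits_{i=1}^N \\nabla_\\theta \\log \\pi_{\\theta}(a_i|s_i)\\Big|_{\\theta=\\theta_{old}} G_{\\theta_{old}}(s_i, a_i)$$\n",
    "\n",
    "So a small step along this gradient is an ordinary policy gradient step. The probability ratio only starts to matter once $\\theta$ moves away from $\\theta_{old}$."
   ]
  },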
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# select probabilities of chosen actions\n",
    "batch_size = actions.shape[0]\n",
    "\n",
    "probs_for_actions = probs[T.arange(batch_size), actions]\n",
    "old_probs_for_actions = old_probs[T.arange(batch_size), actions]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute surrogate loss: negative importance-sampled policy gradient\n",
    "\n",
    "L_surr = <compute surrogate loss, aka _negative_ importance-sampled policy gradient >"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute and return surrogate policy gradient\n",
    "\n",
    "\n",
    "def get_flat_gradient(loss, var_list):\n",
    "    \"\"\"gradient of loss wrt var_list flattened into a large vector\"\"\"\n",
    "    grads = T.grad(loss, var_list)\n",
    "    return T.concatenate([grad.ravel() for grad in grads])\n",
    "\n",
    "\n",
    "get_surrogate_gradients = theano.function(all_inputs, get_flat_gradient(L_surr, weights),\n",
    "                                          allow_input_downcast=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can ascend these gradients as long as our $pi_\\theta(a|s)$ satisfies the constraint\n",
    "$$E_{s,\\pi_{\\Theta_{t}}}\\Big[KL(\\pi(\\Theta_{t}, s) \\:||\\:\\pi(\\Theta_{t+1}, s))\\Big]< \\alpha$$\n",
    "\n",
    "\n",
    "where\n",
    "\n",
    "$$KL(p||q) = E _p log({p \\over q})$$"
   ]
  },
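  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a discrete action space, the expectation in the KL definition is a sum over all actions. As a small worked example with made-up numbers (two actions, not taken from the Acrobot environment), let $p = (0.5, 0.5)$ and $q = (0.25, 0.75)$, with natural logarithms:\n",
    "\n",
    "$$KL(p||q) = \\sum\\limits_a p(a) \\log {p(a) \\over q(a)} = 0.5 \\log {0.5 \\over 0.25} + 0.5 \\log {0.5 \\over 0.75} \\approx 0.3466 - 0.2027 \\approx 0.144$$\n",
    "\n",
    "The divergence is not symmetric: swapping the arguments gives $KL(q||p) \\approx 0.131$, and $KL(p||p) = 0$. For the same $p$, the entropy is\n",
    "\n",
    "$$H(p) = -\\sum\\limits_a p(a) \\log p(a) = \\log 2 \\approx 0.693$$"
   ]
  },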
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute Kullback-Leibler divergence (see formula above)\n",
    "# Note: you need to sum KL and entropy over all actions, not just the ones agent took\n",
    "old_log_probs = T.log(old_probs + 1e-10)\n",
    "\n",
    "kl = <cumpute kullback-leibler as per formula above >\n",
    "\n",
    "# Compute policy entropy\n",
    "entropy = <compute policy entropy. Don't forget the sign!>\n",
    "\n",
    "compute_losses = theano.function(all_inputs, [L_surr, kl, entropy],\n",
    "                                 allow_input_downcast=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Linear search**\n",
    "\n",
    "TRPO in its core involves ascending surrogate policy gradient constrained by KL divergence. \n",
    "\n",
    "In order to enforce this constraint, we're gonna use linesearch. You can find out more about it [here](https://en.wikipedia.org/wiki/Linear_search)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def linesearch(f, x, fullstep, max_kl):\n",
    "    \"\"\"\n",
    "    Linesearch finds the best parameters of neural networks in the direction of fullstep contrainted by KL divergence.\n",
    "    :param: f - function that returns loss, kl and arbitrary third component.\n",
    "    :param: x - old parameters of neural network.\n",
    "    :param: fullstep - direction in which we make search.\n",
    "    :param: max_kl - constraint of KL divergence.\n",
    "    :returns:\n",
    "    \"\"\"\n",
    "    max_backtracks = 10\n",
    "    loss, _, _ = f(x)\n",
    "    for stepfrac in .5 ** np.arange(max_backtracks):\n",
    "        xnew = x + stepfrac * fullstep\n",
    "        new_loss, kl, _ = f(xnew)\n",
    "        actual_improve = new_loss - loss\n",
    "        if kl <= max_kl and actual_improve < 0:\n",
    "            x = xnew\n",
    "            loss = new_loss\n",
    "    return x"
   ]
  },
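  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The acceptance rule implemented by `linesearch` above can be summarized as follows. The symbols are introduced here only for illustration: $\\Delta$ stands for `fullstep`, $\\delta$ for `max_kl`, $x_0$ for the old parameters, and $KL(x')$ for the KL divergence reported by `f` at the candidate point. For $j = 0, 1, \\dots, 9$:\n",
    "\n",
    "$$x' = x_j + 2^{-j} \\Delta, \\qquad x_{j+1} = \\begin{cases} x' & \\text{if } KL(x') \\le \\delta \\text{ and } L_{surr}(x') < L_{surr}(x_j) \\\\\\\\ x_j & \\text{otherwise} \\end{cases}$$\n",
    "\n",
    "Note that, as written, the loop does not stop after the first accepted step: later, smaller steps are added on top of the accepted point whenever they further decrease the loss while staying within the KL bound."
   ]
  },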
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 4: training\n",
    "In this section we construct rest parts of our computational graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def slice_vector(vector, shapes):\n",
    "    \"\"\"\n",
    "    Slices symbolic vector into several symbolic tensors of given shapes.\n",
    "    Auxilary function used to un-flatten gradients, tangents etc.\n",
    "    :param vector: 1-dimensional symbolic vector\n",
    "    :param shapes: list or tuple of shapes (list, tuple or symbolic)\n",
    "    :returns: list of symbolic tensors of given shapes\n",
    "    \"\"\"\n",
    "    assert vector.ndim == 1, \"vector must be 1-dimensional\"\n",
    "    start = 0\n",
    "    tensors = []\n",
    "    for shape in shapes:\n",
    "        size = T.prod(shape)\n",
    "        tensor = vector[start:(start + size)].reshape(shape)\n",
    "        tensors.append(tensor)\n",
    "        start += size\n",
    "    return tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "conjugate_grad_intermediate_vector = T.vector(\n",
    "    \"intermediate grad in conjugate_gradient\")\n",
    "\n",
    "# slice flat_tangent into chunks for each weight\n",
    "weight_shapes = [var.get_value().shape for var in weights]\n",
    "tangents = slice_vector(conjugate_grad_intermediate_vector, weight_shapes)\n",
    "\n",
    "# KL divergence where first arg is fixed\n",
    "from theano.gradient import disconnected_grad as const\n",
    "kl_firstfixed = (const(probs) * (const(T.log(probs)) -\n",
    "                                 T.log(probs))).sum(axis=-1).mean()\n",
    "\n",
    "# compute fisher information matrix (used for conjugate gradients and to estimate KL)\n",
    "gradients = T.grad(kl_firstfixed, weights)\n",
    "gradient_vector_product = [T.sum(g * t) for (g, t) in zip(gradients, tangents)]\n",
    "\n",
    "fisher_vector_product = get_flat_gradient(\n",
    "    sum(gradient_vector_product), weights)\n",
    "\n",
    "compute_fisher_vector_product = theano.function(\n",
    "    [observations, conjugate_grad_intermediate_vector], fisher_vector_product)"
   ]
  },
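  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sketch of why the cell above never needs the Fisher matrix itself (the notation is introduced here for illustration: $v$ stands for `conjugate_grad_intermediate_vector`, $\\theta$ for the fixed first argument and $\\theta'$ for the differentiated one). With the first argument held constant, the KL divergence and its gradient both vanish at $\\theta' = \\theta$, so to second order\n",
    "\n",
    "$$KL(\\pi_\\theta || \\pi_{\\theta'}) \\approx {1 \\over 2} (\\theta' - \\theta)^T F (\\theta' - \\theta), \\qquad F = \\nabla^2_{\\theta'} KL(\\pi_\\theta || \\pi_{\\theta'})\\Big|_{\\theta'=\\theta}$$\n",
    "\n",
    "The product of $F$ with any constant vector $v$ can then be obtained with two gradient passes:\n",
    "\n",
    "$$F v = \\nabla_{\\theta'} \\Big( \\big(\\nabla_{\\theta'} KL\\big)^T v \\Big)\\Big|_{\\theta'=\\theta}$$\n",
    "\n",
    "which is what `gradient_vector_product` followed by `get_flat_gradient` computes."
   ]
  },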
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TRPO helpers\n",
    "\n",
    "Here we define a few helper functions used in the main TRPO loop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Conjugate gradients**\n",
    "\n",
    "Since TRPO includes contrainted optimization, we will need to solve Ax=b using conjugate gradients.\n",
    "\n",
    "In general, CG is an algorithm that solves Ax=b where A is positive-defined. A is Hessian matrix so A is positive-defined. You can find out more about them [here](https://en.wikipedia.org/wiki/Conjugate_gradient_method)"
   ]
  },
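  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see where the linear system comes from, here is a sketch of the standard derivation. The symbols are introduced here for illustration: $g$ is the flat gradient of $L_{surr}$, $F$ the Fisher matrix, $\\delta$ the KL bound and $\\Delta$ the parameter update. Linearizing the loss and approximating the KL divergence by its quadratic form gives\n",
    "\n",
    "$$\\min_{\\Delta} \\; g^T \\Delta \\quad \\text{s.t.} \\quad {1 \\over 2} \\Delta^T F \\Delta \\le \\delta$$\n",
    "\n",
    "Setting the gradient of the Lagrangian to zero yields a step proportional to $s$, where $F s = -g$, and the constraint fixes its length:\n",
    "\n",
    "$$\\Delta = \\sqrt{2 \\delta \\over s^T F s} \\; s$$\n",
    "\n",
    "CG solves $F s = -g$ for the direction, and the square-root factor rescales it to reach the KL bound. In the training loop of this notebook these quantities correspond to `stepdir`, `shs` and `lm`."
   ]
  },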
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from numpy.linalg import inv\n",
    "\n",
    "\n",
    "def conjugate_gradient(f_Ax, b, cg_iters=10, residual_tol=1e-10):\n",
    "    \"\"\"\n",
    "    This method solves system of equation Ax=b using iterative method called conjugate gradients\n",
    "    :f_Ax: function that returns Ax\n",
    "    :b: targets for Ax\n",
    "    :cg_iters: how many iterations this method should do\n",
    "    :residual_tol: epsilon for stability\n",
    "    \"\"\"\n",
    "    p = b.copy()\n",
    "    r = b.copy()\n",
    "    x = np.zeros_like(b)\n",
    "    rdotr = r.dot(r)\n",
    "    for i in range(cg_iters):\n",
    "        z = f_Ax(p)\n",
    "        v = rdotr / (p.dot(z) + 1e-8)\n",
    "        x += v * p\n",
    "        r -= v * z\n",
    "        newrdotr = r.dot(r)\n",
    "        mu = newrdotr / (rdotr + 1e-8)\n",
    "        p = r + mu * p\n",
    "        rdotr = newrdotr\n",
    "        if rdotr < residual_tol:\n",
    "            break\n",
    "    return x"
   ]
  },
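  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In textbook notation, the loop above performs the following updates, starting from $x_0 = 0$ and $r_0 = p_0 = b$. Here $\\alpha_k$ corresponds to `v` and $\\beta_k$ to `mu` in the code; the small `1e-8` terms in the denominators are omitted for readability:\n",
    "\n",
    "$$\\begin{aligned}\n",
    "\\alpha_k &= {r_k^T r_k \\over p_k^T A p_k} \\\\\\\\\n",
    "x_{k+1} &= x_k + \\alpha_k p_k \\\\\\\\\n",
    "r_{k+1} &= r_k - \\alpha_k A p_k \\\\\\\\\n",
    "\\beta_k &= {r_{k+1}^T r_{k+1} \\over r_k^T r_k} \\\\\\\\\n",
    "p_{k+1} &= r_{k+1} + \\beta_k p_k\n",
    "\\end{aligned}$$\n",
    "\n",
    "In exact arithmetic, CG reaches the solution of an $n \\times n$ symmetric positive definite system in at most $n$ iterations, so a handful of iterations is often enough for a usable approximate solution."
   ]
  },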
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This code validates conjugate gradients\n",
    "A = np.random.rand(8, 8)\n",
    "A = np.matmul(np.transpose(A), A)\n",
    "\n",
    "\n",
    "def f_Ax(x):\n",
    "    return np.matmul(A, x.reshape(-1, 1)).reshape(-1)\n",
    "\n",
    "\n",
    "b = np.random.rand(8)\n",
    "\n",
    "w = np.matmul(np.matmul(inv(np.matmul(np.transpose(A), A)),\n",
    "                        np.transpose(A)), b.reshape((-1, 1))).reshape(-1)\n",
    "print(w)\n",
    "print(conjugate_gradient(f_Ax, b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile a function that exports network weights as a vector\n",
    "flat_weights = T.concatenate([var.ravel() for var in weights])\n",
    "get_flat_weights = theano.function([], flat_weights)\n",
    "\n",
    "# ... and another function that imports vector back into network weights\n",
    "flat_weights_placeholder = T.vector(\"flattened weights\")\n",
    "assigns = slice_vector(flat_weights_placeholder, weight_shapes)\n",
    "\n",
    "load_flat_weights = theano.function(\n",
    "    [flat_weights_placeholder], updates=dict(zip(weights, assigns)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Step 5: Main TRPO loop\n",
    "\n",
    "Here we will train our network!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from itertools import count\n",
    "from collections import OrderedDict\n",
    "\n",
    "# this is hyperparameter of TRPO. It controls how big KL divergence may be between old and new policy every step.\n",
    "max_kl = 0.01\n",
    "cg_damping = 0.1  # This parameters regularize addition to\n",
    "numeptotal = 0  # this is number of episodes that we played.\n",
    "\n",
    "start_time = time.time()\n",
    "\n",
    "for i in count(1):\n",
    "\n",
    "    print(\"\\n********** Iteration %i ************\" % i)\n",
    "\n",
    "    # Generating paths.\n",
    "    print(\"Rollout\")\n",
    "    paths = rollout(env, act)\n",
    "    print(\"Made rollout\")\n",
    "\n",
    "    # Updating policy.\n",
    "    observations = np.concatenate([path[\"observations\"] for path in paths])\n",
    "    actions = np.concatenate([path[\"actions\"] for path in paths])\n",
    "    returns = np.concatenate([path[\"cumulative_returns\"] for path in paths])\n",
    "    old_probs = np.concatenate([path[\"policy\"] for path in paths])\n",
    "    inputs_batch = [observations, actions, returns, old_probs]\n",
    "\n",
    "    old_weights = get_flat_weights()\n",
    "\n",
    "    def fisher_vector_product(p):\n",
    "        \"\"\"gets intermediate grads (p) and computes fisher*vector \"\"\"\n",
    "        return compute_fisher_vector_product(observations, p) + cg_damping * p\n",
    "\n",
    "    flat_grad = get_surrogate_gradients(*inputs_batch)\n",
    "\n",
    "    stepdir = conjugate_gradient(fisher_vector_product, -flat_grad)\n",
    "    shs = .5 * stepdir.dot(fisher_vector_product(stepdir))\n",
    "    lm = np.sqrt(shs / max_kl)\n",
    "    fullstep = stepdir / lm\n",
    "\n",
    "    # Compute new weights with linesearch in the direction we found with CG\n",
    "\n",
    "    def losses_f(flat_weights):\n",
    "        load_flat_weights(flat_weights)\n",
    "        return compute_losses(*inputs_batch)\n",
    "\n",
    "    new_weights = linesearch(losses_f, old_weights, fullstep, max_kl)\n",
    "\n",
    "    load_flat_weights(new_weights)\n",
    "\n",
    "    # Report current progress\n",
    "    L_surr, kl, entropy = compute_losses(*inputs_batch)\n",
    "    episode_rewards = np.array([path[\"rewards\"].sum() for path in paths])\n",
    "\n",
    "    stats = OrderedDict()\n",
    "    numeptotal += len(episode_rewards)\n",
    "    stats[\"Total number of episodes\"] = numeptotal\n",
    "    stats[\"Average sum of rewards per episode\"] = episode_rewards.mean()\n",
    "    stats[\"Std of rewards per episode\"] = episode_rewards.std()\n",
    "    stats[\"Entropy\"] = entropy\n",
    "    stats[\"Time elapsed\"] = \"%.2f mins\" % ((time.time() - start_time)/60.)\n",
    "    stats[\"KL between old and new distribution\"] = kl\n",
    "    stats[\"Surrogate loss\"] = L_surr\n",
    "    for k, v in stats.items():\n",
    "        print(k + \": \" + \" \" * (40 - len(k)) + str(v))\n",
    "    i += 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Homework option I: better sampling (10+pts)\n",
    "\n",
    "In this section, you're invited to implement a better rollout strategy called _vine_.\n",
    "\n",
    "![img](https://s17.postimg.cc/i90chxgvj/vine.png)\n",
    "\n",
    "In most gym environments, you can actually backtrack by using states. You can find a wrapper that saves/loads states in [the mcts seminar](https://github.com/yandexdataschool/Practical_RL/blob/master/week10_planning/seminar_MCTS.ipynb).\n",
    "\n",
    "You can read more about in the [TRPO article](https://arxiv.org/abs/1502.05477) in section 5.2.\n",
    "\n",
    "The goal here is to implement such rollout policy (we recommend using tree data structure like in the seminar above).\n",
    "Then you can assign cummulative rewards similar to `get_cummulative_rewards`, but for a tree.\n",
    "\n",
    "__bonus task__ - parallelize samples using multiple cores"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Homework option II (10+pts)\n",
    "\n",
    "Let's use TRPO to train evil robots! (pick any of two)\n",
    "* [MuJoCo robots](https://gym.openai.com/envs#mujoco)\n",
    "* [Box2d robot](https://gym.openai.com/envs/BipedalWalker-v2)\n",
    "\n",
    "The catch here is that those environments have continuous action spaces. \n",
    "\n",
    "Luckily, TRPO is a policy gradient method, so it's gonna work for any parametric $\\pi_\\theta(a|s)$. We recommend starting with gaussian policy:\n",
    "\n",
    "$$\\pi_\\theta(a|s) = N(\\mu_\\theta(s),\\sigma^2_\\theta(s)) = {1 \\over \\sqrt { 2 \\pi {\\sigma^2}_\\theta(s) } } e^{ (a - \n",
    "\\mu_\\theta(s))^2 \\over 2 {\\sigma^2}_\\theta(s) } $$\n",
    "\n",
    "In the $\\sqrt { 2 \\pi {\\sigma^2}_\\theta(s) }$ clause, $\\pi$ means ~3.1415926, not agent's policy.\n",
    "\n",
    "This essentially means that you will need two output layers:\n",
    "* $\\mu_\\theta(s)$, a dense layer with linear activation\n",
    "* ${\\sigma^2}_\\theta(s)$, a dense layer with activation T.exp (to make it positive; like rho from bandits)\n",
    "\n",
    "For multidimensional actions, you can use fully factorized gaussian (basically a vector of gaussians).\n",
    "\n",
    "__bonus task__: compare performance of continuous action space method to action space discretization"
   ]
  },
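  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To reuse the TRPO machinery from this notebook with such a policy, you would need its log-density, KL divergence and entropy in closed form. For one-dimensional Gaussians these are standard results, restated here as a convenience. The notation is ours: subscripts 1 and 2 denote the old and new policy at a fixed state, and the dependence on $s$ and $\\theta$ is dropped.\n",
    "\n",
    "$$\\log \\pi_\\theta(a|s) = - {(a - \\mu)^2 \\over 2 \\sigma^2} - {1 \\over 2} \\log (2 \\pi \\sigma^2)$$\n",
    "\n",
    "$$KL\\big(N(\\mu_1, \\sigma_1^2) \\:||\\: N(\\mu_2, \\sigma_2^2)\\big) = \\log {\\sigma_2 \\over \\sigma_1} + {\\sigma_1^2 + (\\mu_1 - \\mu_2)^2 \\over 2 \\sigma_2^2} - {1 \\over 2}$$\n",
    "\n",
    "$$H\\big(N(\\mu, \\sigma^2)\\big) = {1 \\over 2} \\log (2 \\pi e \\sigma^2)$$\n",
    "\n",
    "For a fully factorized Gaussian over several action dimensions, each of these quantities is the sum of the per-dimension terms."
   ]
  },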
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
